import argparse
import os
import torch
from tqdm import tqdm
import time
from utils import setup_seed, save_checkpoint, label2class
from dataset.dataloader_hw import get_ddt_dataloader
from dataset.hw2 import HW
from model import PointPillarsV2
from loss import Loss
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from wwengine.config import Config
from config.env import project_path
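# Rendezvous address/port for single-node distributed training; every process
# spawned below joins the group through this endpoint.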
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Get the project root; all later paths are relative to it, so the code runs
# unchanged on different machines.
ProjectPath = project_path()

def save_summary(writer, loss_dict, global_step, tag, lr=None, momentum=None):
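    # Write each loss term to TensorBoard under `tag`, plus the current
    # learning rate and momentum when provided.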
    for k, v in loss_dict.items():
        writer.add_scalar(f'{tag}/{k}', v, global_step)
    if lr is not None:
        writer.add_scalar('lr', lr, global_step)
    if momentum is not None:
        writer.add_scalar('momentum', momentum, global_step)


def main(rank, args):
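    # `rank` indexes the spawned process and doubles as the local GPU id.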
    # Initialize the process group (single node). gloo also runs without GPUs;
    # nccl is usually the faster choice for multi-GPU training.
    dist.init_process_group("gloo", rank=rank, world_size=args.world_size)
    torch.cuda.set_device(rank)
    setup_seed()
    cfg = Config.fromfile(os.path.join(ProjectPath, 'config/default.json'))

    train_dataset = HW(cfg=cfg['DATASETS'], split='train')
    val_dataset = HW(cfg=cfg['DATASETS'], split='val')

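    # Derive the label -> class-name mapping and the class count from the
    # dataset config; this overrides the --nclasses command-line default.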
    Label2Class, args.nclasses = label2class(cfg['DATASETS']['CLASS2LABEL'])

    train_sampler, train_dataloader = get_ddt_dataloader(dataset=train_dataset,
                                                         batch_size=args.batch_size,
                                                         num_workers=args.num_workers,
                                                         shuffle=True)
    val_sampler, val_dataloader = get_ddt_dataloader(dataset=val_dataset,
                                                     batch_size=args.batch_size,
                                                     num_workers=args.num_workers,
                                                     shuffle=False)


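    # Build the detector on this rank's GPU; the point-cloud range and voxel
    # size come from the dataset and model sections of the config.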
    pointpillars = PointPillarsV2(nclasses=args.nclasses,
                                point_cloud_range=cfg['DATASETS'].try_get('PTS_RANGE'),
                                voxel_size=cfg['MODEL'].try_get('Voxel_Size'),
                                use_intensity=args.use_intensity).to(rank)

    loss_func = Loss()

    max_iters = len(train_dataloader) * args.max_epoch
    init_lr = args.init_lr
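    # AdamW with a one-cycle schedule: lr warms up from init_lr to 10x init_lr
    # over the first 40% of all steps, then cosine-anneals back; momentum
    # (beta1) cycles inversely between ~0.85 and 0.95.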
    optimizer = torch.optim.AdamW(params=pointpillars.parameters(),
                                  lr=init_lr,
                                  betas=(0.95, 0.99),
                                  weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr=init_lr * 10,
                                                    total_steps=max_iters,
                                                    pct_start=0.4,
                                                    anneal_strategy='cos',
                                                    cycle_momentum=True,
                                                    base_momentum=0.95 * 0.895,
                                                    max_momentum=0.95,
                                                    div_factor=10)
    time_str = time.strftime('%Y-%m-%d-%H-%M')  # timestamp string for this run
    saved_logs_path = os.path.join(args.saved_path, 'summary')
    os.makedirs(saved_logs_path, exist_ok=True)
    writer = SummaryWriter(saved_logs_path)
    saved_ckpt_path = os.path.join(args.saved_path, 'Time_{}/checkpoints'.format(time_str))
    os.makedirs(saved_ckpt_path, exist_ok=True)

    epoch_start = 0
    # Load pretrained weights. The model is not wrapped in DDP yet, so load into
    # the bare module; strip a possible 'module.' prefix left by a DDP-saved
    # checkpoint. map_location keeps every rank off rank 0's GPU.
    if os.path.exists(args.pretrained):
        print("=> loading model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained, map_location='cuda:{}'.format(rank))
        state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                      for k, v in checkpoint['state_dict'].items()}
        pointpillars.load_state_dict(state_dict)
        # optimizer.state = checkpoint['optimizer']['state']
        # optimizer.param_groups = checkpoint['optimizer']['param_groups']
        epoch_start = checkpoint['epoch'] + 1

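    # Wrap in DDP only after loading weights so every rank starts from
    # identical parameters.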
    pointpillars = torch.nn.parallel.DistributedDataParallel(pointpillars, device_ids=[rank])

    for epoch in range(epoch_start, args.max_epoch):
        print('=' * 20, epoch, '=' * 20)
        train_step, val_step = 0, 0
        for i, data_dict in enumerate(tqdm(train_dataloader)):
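            # Skip batches whose first sample carries no ground-truth boxes.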
            if data_dict['batched_labels'][0].shape[0] == 0:
                continue
            # move tensors to this rank's CUDA device
            for key in data_dict:
                for j, item in enumerate(data_dict[key]):
                    if torch.is_tensor(item):
                        data_dict[key][j] = data_dict[key][j].to(rank)

            optimizer.zero_grad()

            batched_pts = data_dict['batched_pts']
            batched_gt_bboxes = data_dict['batched_gt_bboxes']
            batched_labels = data_dict['batched_labels']

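            # A forward pass in 'train' mode also builds the per-anchor targets
            # consumed by the loss below.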
            bbox_cls_pred, bbox_pred, bbox_dir_cls_pred, anchor_target_dict = \
                pointpillars(batched_pts=batched_pts,
                             mode='train',
                             batched_gt_bboxes=batched_gt_bboxes,
                             batched_gt_labels=batched_labels)

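            # Flatten the dense head outputs channel-last so each row holds one
            # anchor's prediction, matching the flattened targets below.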
            bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(-1, args.nclasses)
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 7)
            bbox_dir_cls_pred = bbox_dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)

            batched_bbox_labels = anchor_target_dict['batched_labels'].reshape(-1)
            batched_label_weights = anchor_target_dict['batched_label_weights'].reshape(-1)
            batched_bbox_reg = anchor_target_dict['batched_bbox_reg'].reshape(-1, 7)
            # batched_bbox_reg_weights = anchor_target_dict['batched_bbox_reg_weights'].reshape(-1)
            batched_dir_labels = anchor_target_dict['batched_dir_labels'].reshape(-1)
            # batched_dir_labels_weights = anchor_target_dict['batched_dir_labels_weights'].reshape(-1)

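            # Box regression and direction losses use positive anchors only,
            # i.e. those assigned a class id in [0, nclasses).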
            pos_idx = (batched_bbox_labels >= 0) & (batched_bbox_labels < args.nclasses)
            bbox_pred = bbox_pred[pos_idx]
            batched_bbox_reg = batched_bbox_reg[pos_idx]
            # sin(a - b) = sin(a)*cos(b) - cos(a)*sin(b)
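            # This encoding makes smooth-L1 penalize sin(pred - gt), whose
            # magnitude is unchanged by a pi flip; the direction classifier
            # disambiguates the heading.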
            pred_angle = bbox_pred[:, -1].clone()
            target_angle = batched_bbox_reg[:, -1].clone()
            bbox_pred[:, -1] = torch.sin(pred_angle) * torch.cos(target_angle)
            batched_bbox_reg[:, -1] = torch.cos(pred_angle) * torch.sin(target_angle)
            bbox_dir_cls_pred = bbox_dir_cls_pred[pos_idx]
            batched_dir_labels = batched_dir_labels[pos_idx]

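            # Classification uses every anchor with a nonzero label weight;
            # ignored labels (-1) are remapped to the background id (nclasses)
            # before the loss is computed.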
            num_cls_pos = (batched_bbox_labels < args.nclasses).sum()
            bbox_cls_pred = bbox_cls_pred[batched_label_weights > 0]
            batched_bbox_labels[batched_bbox_labels < 0] = args.nclasses
            batched_bbox_labels = batched_bbox_labels[batched_label_weights > 0]

            loss_dict = loss_func(bbox_cls_pred=bbox_cls_pred,
                                  bbox_pred=bbox_pred,
                                  bbox_dir_cls_pred=bbox_dir_cls_pred,
                                  batched_labels=batched_bbox_labels,
                                  num_cls_pos=num_cls_pos,
                                  batched_bbox_reg=batched_bbox_reg,
                                  batched_dir_labels=batched_dir_labels)

            loss = loss_dict['total_loss']
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(pointpillars.parameters(), max_norm=35)
            optimizer.step()
            scheduler.step()

            # print("loss is {}".format(loss))

            global_step = epoch * len(train_dataloader) + train_step + 1

            if global_step % args.log_freq == 0 and rank == 0:
                save_summary(writer, loss_dict, global_step, 'train',
                             lr=optimizer.param_groups[0]['lr'],
                             momentum=optimizer.param_groups[0]['betas'][0])
            train_step += 1
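        # Only rank 0 writes checkpoints, so concurrent ranks never race on the
        # same files.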
        if (epoch + 1) % args.ckpt_freq_epoch == 0 and rank == 0:
            # torch.save(pointpillars.state_dict(), os.path.join(saved_ckpt_path, f'epoch_{epoch+1}.pth'))
            save_checkpoint(epoch=epoch, name="PointPillar_HW", model=pointpillars, optimizer=optimizer,
                            output_dir=saved_ckpt_path, is_best=False)

        if epoch % args.val_freq_epoch != 0:
            continue
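        # Validation reuses the training forward path (mode='train') so the same
        # losses can be reported, but under no_grad and with the model in eval mode.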
        pointpillars.eval()
        with torch.no_grad():
            for i, data_dict in enumerate(tqdm(val_dataloader)):
                if data_dict['batched_labels'][0].shape[0] == 0:
                    continue
                # move tensors to this rank's CUDA device
                for key in data_dict:
                    for j, item in enumerate(data_dict[key]):
                        if torch.is_tensor(item):
                            data_dict[key][j] = data_dict[key][j].to(rank)

                batched_pts = data_dict['batched_pts']
                for k in range(len(batched_pts)):
                    batched_pts[k] = batched_pts[k].to(torch.float32).to(rank)
                # batched_pts = data_dict['batched_pts'].to(torch.float32).to(rank)
                batched_gt_bboxes = data_dict['batched_gt_bboxes']
                batched_labels = data_dict['batched_labels']
                bbox_cls_pred, bbox_pred, bbox_dir_cls_pred, anchor_target_dict = \
                    pointpillars(batched_pts=batched_pts,
                                 mode='train',
                                 batched_gt_bboxes=batched_gt_bboxes,
                                 batched_gt_labels=batched_labels)

                bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(-1, args.nclasses)
                bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 7)
                bbox_dir_cls_pred = bbox_dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)

                batched_bbox_labels = anchor_target_dict['batched_labels'].reshape(-1)
                batched_label_weights = anchor_target_dict['batched_label_weights'].reshape(-1)
                batched_bbox_reg = anchor_target_dict['batched_bbox_reg'].reshape(-1, 7)
                # batched_bbox_reg_weights = anchor_target_dict['batched_bbox_reg_weights'].reshape(-1)
                batched_dir_labels = anchor_target_dict['batched_dir_labels'].reshape(-1)
                # batched_dir_labels_weights = anchor_target_dict['batched_dir_labels_weights'].reshape(-1)

                pos_idx = (batched_bbox_labels >= 0) & (batched_bbox_labels < args.nclasses)
                bbox_pred = bbox_pred[pos_idx]
                batched_bbox_reg = batched_bbox_reg[pos_idx]
                # sin(a - b) = sin(a)*cos(b) - cos(a)*sin(b)
                pred_angle = bbox_pred[:, -1].clone()
                target_angle = batched_bbox_reg[:, -1].clone()
                bbox_pred[:, -1] = torch.sin(pred_angle) * torch.cos(target_angle)
                batched_bbox_reg[:, -1] = torch.cos(pred_angle) * torch.sin(target_angle)
                bbox_dir_cls_pred = bbox_dir_cls_pred[pos_idx]
                batched_dir_labels = batched_dir_labels[pos_idx]

                num_cls_pos = (batched_bbox_labels < args.nclasses).sum()
                bbox_cls_pred = bbox_cls_pred[batched_label_weights > 0]
                batched_bbox_labels[batched_bbox_labels < 0] = args.nclasses
                batched_bbox_labels = batched_bbox_labels[batched_label_weights > 0]

                loss_dict = loss_func(bbox_cls_pred=bbox_cls_pred,
                                      bbox_pred=bbox_pred,
                                      bbox_dir_cls_pred=bbox_dir_cls_pred,
                                      batched_labels=batched_bbox_labels,
                                      num_cls_pos=num_cls_pos,
                                      batched_bbox_reg=batched_bbox_reg,
                                      batched_dir_labels=batched_dir_labels)

                print("total_loss is {}".format(loss_dict['total_loss']))

                global_step = epoch * len(val_dataloader) + val_step + 1
                if global_step % args.log_freq == 0 and rank == 0:
                    save_summary(writer, loss_dict, global_step, 'val')
                val_step += 1
        pointpillars.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Configuration Parameters')
    parser.add_argument('--saved_path', default='pillar_logs')
    parser.add_argument('--pretrained', default='')  # path to pretrained weights
    # a plain flag: type=bool would parse any non-empty string as True
    parser.add_argument('--use_intensity', action='store_true')  # train with the intensity channel
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--nclasses', type=int, default=4)
    parser.add_argument('--init_lr', type=float, default=0.00025)
    parser.add_argument('--max_epoch', type=int, default=160)
    parser.add_argument('--log_freq', type=int, default=8)
    parser.add_argument('--ckpt_freq_epoch', type=int, default=5)
    parser.add_argument('--val_freq_epoch', type=int, default=5)
    parser.add_argument('--world_size', type=int, default=1)
    args = parser.parse_args()

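    # Spawn one worker per process rank; mp.spawn passes the rank as the first
    # argument to main (e.g. --world_size 2 launches ranks 0 and 1).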
    mp.spawn(main,
             args=(args,),
             nprocs=args.world_size,
             join=True)